/*
 * Copyright (C) 2023 Coder.AN
 * Email: an.hongjun@foxmail.com
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */
#include "edge_tensor.h"
#include "edge_tensor_kernel.h"
#include <stdio.h>

namespace flip{
/*
 * Decompose the flat element index `idx` into per-axis coordinates
 * (*n, *h, *w, *c) for the memory layout named by `order`.
 * For an unrecognized order the output pointers are left untouched.
 * NOTE(review): assumes dim.H/W/C are positive — TODO confirm with callers.
 */
__device__ void cal_dim(int idx, int* n, int* h, int* w, int*c, Dim dim, DimOrder order)
{
    // Peel digits from the innermost (fastest-varying) axis outward;
    // truncating div/mod makes this exactly equivalent to the
    // outermost-first divide-and-subtract formulation.
    switch (order)
    {
    case NHWC:
        *c = idx % dim.C;  idx /= dim.C;
        *w = idx % dim.W;  idx /= dim.W;
        *h = idx % dim.H;  idx /= dim.H;
        *n = idx;
        break;
    case NCWH:
        *h = idx % dim.H;  idx /= dim.H;
        *w = idx % dim.W;  idx /= dim.W;
        *c = idx % dim.C;  idx /= dim.C;
        *n = idx;
        break;
    case NCHW:
        *w = idx % dim.W;  idx /= dim.W;
        *h = idx % dim.H;  idx /= dim.H;
        *c = idx % dim.C;  idx /= dim.C;
        *n = idx;
        break;
    default:
        break;
    }
}

/*
 * Re-encode coordinates (n, h, w, c) as a flat element index under the
 * memory layout named by `order`. Returns 0 for an unknown order.
 */
__device__ int cal_new_idx(int n, int h, int w, int c, Dim dim, DimOrder order)
{
    // Horner-form of the strided sum: outermost axis first.
    switch (order)
    {
    case NHWC:
        return ((n * dim.H + h) * dim.W + w) * dim.C + c;
    case NCWH:
        return ((n * dim.C + c) * dim.W + w) * dim.H + h;
    case NCHW:
        return ((n * dim.C + c) * dim.H + h) * dim.W + w;
    default:
        return 0;
    }
}
}

/*
 * Vertical flip (mirror along the H axis): each thread produces one
 * element of `dst` by reading the row-mirrored element of `src`.
 * Expected launch: 1-D grid/block with at least `unit_count` threads.
 * `unit_size` is the element width in bytes; values other than
 * 1/2/4/8 leave the corresponding `dst` element untouched.
 */
__global__ void VFlip_kernel(const void* src, void* dst, Dim dim, DimOrder dim_order, 
                                size_t unit_size, size_t unit_count)
{
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= unit_count) return;

    int n, h, w, c;
    flip::cal_dim(idx, &n, &h, &w, &c, dim, dim_order);

    // Source element: same (n, w, c), row mirrored across H.
    const int src_idx = flip::cal_new_idx(n, dim.H - 1 - h, w, c, dim, dim_order);

    switch (unit_size)
    {
    case 1:
        ((uint8_t*)dst)[idx] = ((const uint8_t*)src)[src_idx];
        break;
    case 2:
        ((uint16_t*)dst)[idx] = ((const uint16_t*)src)[src_idx];
        break;
    case 4:
        ((uint32_t*)dst)[idx] = ((const uint32_t*)src)[src_idx];
        break;
    case 8:
        ((uint64_t*)dst)[idx] = ((const uint64_t*)src)[src_idx];
        break;
    default:
        break;
    }
}

/*
 * Flip a device tensor vertically (mirror along H).
 *
 * src/dst      : device pointers; must not alias (each thread reads a
 *                different element than it writes).
 * dim/dim_order: tensor extents and memory layout (NHWC/NCWH/NCHW).
 * unit_size    : element width in bytes (1, 2, 4 or 8).
 * unit_count   : total number of elements to process.
 *
 * Synchronous: returns after the kernel has completed (default stream).
 *
 * Fixes: the previous grid computation `(unit_count - 1) / block.x + 1`
 * underflowed size_t when unit_count == 0, requesting ~SIZE_MAX blocks;
 * launch-configuration errors were also never checked (the launch itself
 * returns no status — it must be read via cudaGetLastError()).
 */
void VFlip(const void* src, void* dst, Dim dim, DimOrder dim_order, 
            size_t unit_size, size_t unit_count)
{
    if (unit_count == 0) return;  // nothing to do; also avoids size_t underflow below

    dim3 block(1024);
    dim3 grid((unit_count + block.x - 1) / block.x);  // ceil-div, safe for all unit_count > 0
    VFlip_kernel<<<grid, block>>>(src, dst, dim, dim_order, unit_size, unit_count);
    CHECK(cudaGetLastError());       // catch bad launch configuration
    CHECK(cudaDeviceSynchronize());  // catch asynchronous execution errors
}

/*
 * Horizontal flip (mirror along the W axis): each thread produces one
 * element of `dst` by reading the column-mirrored element of `src`.
 * Expected launch: 1-D grid/block with at least `unit_count` threads.
 * `unit_size` is the element width in bytes; values other than
 * 1/2/4/8 leave the corresponding `dst` element untouched.
 */
__global__ void HFlip_kernel(const void* src, void* dst, Dim dim, DimOrder dim_order, 
                                size_t unit_size, size_t unit_count)
{
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= unit_count) return;

    int n, h, w, c;
    flip::cal_dim(idx, &n, &h, &w, &c, dim, dim_order);

    // Source element: same (n, h, c), column mirrored across W.
    const int src_idx = flip::cal_new_idx(n, h, dim.W - 1 - w, c, dim, dim_order);

    switch (unit_size)
    {
    case 1:
        ((uint8_t*)dst)[idx] = ((const uint8_t*)src)[src_idx];
        break;
    case 2:
        ((uint16_t*)dst)[idx] = ((const uint16_t*)src)[src_idx];
        break;
    case 4:
        ((uint32_t*)dst)[idx] = ((const uint32_t*)src)[src_idx];
        break;
    case 8:
        ((uint64_t*)dst)[idx] = ((const uint64_t*)src)[src_idx];
        break;
    default:
        break;
    }
}

/*
 * Flip a device tensor horizontally (mirror along W).
 *
 * src/dst      : device pointers; must not alias (each thread reads a
 *                different element than it writes).
 * dim/dim_order: tensor extents and memory layout (NHWC/NCWH/NCHW).
 * unit_size    : element width in bytes (1, 2, 4 or 8).
 * unit_count   : total number of elements to process.
 *
 * Synchronous: returns after the kernel has completed (default stream).
 *
 * Fixes: the previous grid computation `(unit_count - 1) / block.x + 1`
 * underflowed size_t when unit_count == 0, requesting ~SIZE_MAX blocks;
 * launch-configuration errors were also never checked (the launch itself
 * returns no status — it must be read via cudaGetLastError()).
 */
void HFlip(const void* src, void* dst, Dim dim, DimOrder dim_order, 
            size_t unit_size, size_t unit_count)
{
    if (unit_count == 0) return;  // nothing to do; also avoids size_t underflow below

    dim3 block(1024);
    dim3 grid((unit_count + block.x - 1) / block.x);  // ceil-div, safe for all unit_count > 0
    HFlip_kernel<<<grid, block>>>(src, dst, dim, dim_order, unit_size, unit_count);
    CHECK(cudaGetLastError());       // catch bad launch configuration
    CHECK(cudaDeviceSynchronize());  // catch asynchronous execution errors
}